import copy
import math

import numpy as np
from numba import njit
from numba.experimental import jitclass
from sklearn.metrics import accuracy_score

from utils import plot_losses

from numba import int32, float32, boolean

# Typed field layout handed to ``@jitclass``: every attribute assigned on
# ``self`` inside ``LogisticRegression`` must be declared here.  The groups
# below are concatenated in declaration order.
spec = (
    # 1-D float vectors: per-epoch histories (allocated with length
    # ``epochs`` in ``__init__``) and the weight vector.
    [(name, float32[:]) for name in ('losses', 'val_accuracies', 'train_accuracies', 'weights')]
    # Scalar model parameter and step size.
    + [(name, float32) for name in ('bias', 'learning_rate')]
    # Mini-batch (True) versus full-batch (False) gradient descent.
    + [('stochastic', boolean)]
    # Integer bookkeeping for the training loop.
    + [(name, int32) for name in ('batch_size', 'epochs', 'nexamples', 'num_iter_in_epoch')]
    # Training / validation data and the L1 penalty coefficient.
    + [
        ('_x_train', float32[:, :]),
        ('_y_train', float32[:]),
        ('_x_val', float32[:, :]),
        ('_y_val', float32[:]),
        ('l1_weight', float32),
    ]
)


@jitclass(spec)
class LogisticRegression():
    """Binary logistic regression trained by batch or mini-batch gradient descent.

    The class is compiled with ``numba.experimental.jitclass``; its attributes
    are typed by the module-level ``spec`` and all data is stored as float32.

    Args:
        _x_train: 2-D training features, one row per example.
        _y_train: 1-D training labels (compared against 0/1 predictions).
        _x_val: 2-D validation features.
        _y_val: 1-D validation labels.
        learning_rate: Step size of each gradient update.
        stochastic: If true, every update uses ``batch_size`` examples sampled
            with replacement; otherwise the full training set is used.
        epochs: Maximum number of epochs; also the length of the history arrays.
        batch_size: Mini-batch size (only used when ``stochastic`` is true).
        l1_weight: Coefficient of the L1 penalty added to the reported loss.

    Attributes:
        losses, train_accuracies, val_accuracies: Per-epoch history, indexed by
            epoch.  Entries after an early stop keep their initial value of 0.
        weights, bias: Model parameters, randomly initialised in ``[0, 1)``.
    """

    def __init__(self, _x_train, _y_train, _x_val, _y_val, learning_rate, stochastic, epochs, batch_size, l1_weight):
        self.losses = np.zeros(epochs).astype(np.float32)
        self.val_accuracies = np.zeros(epochs).astype(np.float32)
        self.train_accuracies = np.zeros(epochs).astype(np.float32)
        self.weights = np.random.rand(_x_train.shape[1]).astype(np.float32)
        self.bias = np.random.rand()
        self.learning_rate = learning_rate
        self.stochastic = stochastic
        self.batch_size = batch_size
        self.epochs = epochs
        self.nexamples = _x_train.shape[0]
        if self.stochastic:
            # At least one update per epoch, even when batch_size > nexamples
            # (batches are sampled with replacement, so that is still valid).
            self.num_iter_in_epoch = max(1, self.nexamples // self.batch_size)
        else:
            self.num_iter_in_epoch = 1
        self._x_train = _x_train.astype(np.float32)
        self._y_train = _y_train.astype(np.float32)
        self._x_val = _x_val.astype(np.float32)
        self._y_val = _y_val.astype(np.float32)
        self.l1_weight = l1_weight

    def fit(self):
        """Train the model, recording loss and accuracies after every epoch.

        Training stops early once the training loss changes by less than 1e-5
        between two consecutive epochs.  Progress is printed roughly 50 times
        over the run (every epoch when ``epochs < 50``).
        """
        x_all = self._x_train
        y_all = self._y_train
        epochs = self.epochs
        # epochs // 50 is 0 for short runs; clamp so the modulo below is valid.
        log_every = max(1, epochs // 50)
        last_loss = math.inf
        for epoch in range(epochs):
            for _ in range(self.num_iter_in_epoch):
                if self.stochastic:
                    indices = self._sample_batch_indices()
                    x = x_all[indices, :]
                    # Labels are 1-D, so they take a single index.
                    y = y_all[indices]
                else:
                    x = x_all
                    y = y_all

                pred = self.predict_probs(x)
                grad_w, grad_b = self.compute_gradients(x, y, pred)
                # NOTE: the L1 penalty is part of the reported loss but is not
                # added to the gradients here (see ``l1_loss_grad``).
                self.update_model_parameters(grad_w, grad_b)

            pred_to_class_val = self.predict_classes(self._x_val)
            pred_to_class_train = self.predict_classes(self._x_train)
            probs_train = self.predict_probs(self._x_train)
            train_loss = self.compute_loss(self._y_train, probs_train)
            val_accuracy = self.accuracy_score(self._y_val, pred_to_class_val)
            train_accuracy = self.accuracy_score(self._y_train, pred_to_class_train)
            # History arrays are preallocated, so write by epoch index.
            self.val_accuracies[epoch] = val_accuracy
            self.train_accuracies[epoch] = train_accuracy
            self.losses[epoch] = train_loss

            if abs(last_loss - train_loss) < 1e-5:
                print("Early Stopping")
                self._report(epoch, train_loss, train_accuracy, val_accuracy)
                break
            if (epoch + 1) % log_every == 0:
                self._report(epoch, train_loss, train_accuracy, val_accuracy)
            last_loss = train_loss

    def _sample_batch_indices(self):
        """Draw ``batch_size`` training-row indices uniformly with replacement."""
        n = self._x_train.shape[0]
        indices = np.empty(self.batch_size, dtype=np.int64)
        # The two-argument randint form is the one every Numba release supports.
        for i in range(self.batch_size):
            indices[i] = np.random.randint(0, n)
        return indices

    def _report(self, epoch, train_loss, train_accuracy, val_accuracy):
        """Print the metrics of one epoch."""
        # Multi-argument print avoids str(float), which nopython mode lacks.
        print("Epoch:", epoch, "\n Validation Accuracy:", val_accuracy)
        print("Train -  Loss:", train_loss, "Accuracy:", train_accuracy)

    def accuracy_score(self, true_values, predictions):
        """Return the fraction of ``predictions`` equal to ``true_values``."""
        accuracy = (true_values == predictions).sum() / true_values.shape[0]
        return accuracy

    def compute_loss(self, y_true, y_pred):
        """Return mean binary cross-entropy plus the L1 penalty.

        Args:
            y_true: 1-D array of labels.
            y_pred: 1-D array of predicted probabilities.
        """
        # A small epsilon keeps log() away from log(0).
        eps = 1e-9
        positive_term = y_true * np.log(y_pred + eps)
        negative_term = (1 - y_true) * np.log(1 - y_pred + eps)
        mean_cross_entropy = -np.mean(positive_term + negative_term)
        # L1 norm uses absolute values for the bias as well as the weights.
        l1_penalty = self.l1_weight * (np.sum(np.abs(self.weights)) + abs(self.bias))
        return mean_cross_entropy + l1_penalty

    def compute_gradients(self, x, y_true, y_pred):
        """Return the gradients of the mean binary cross-entropy.

        Args:
            x: 2-D float32 feature array, one row per example.
            y_true: 1-D array of labels.
            y_pred: 1-D array of predicted probabilities.

        Returns:
            ``(gradients_w, gradient_b)``: a 1-D array with one entry per
            feature and a scalar, both averaged over the examples.
        """
        n = x.shape[0]
        # Cast so both np.dot operands share a dtype (required by Numba).
        difference = (y_pred - y_true).astype(np.float32)
        gradient_b = np.mean(difference)
        # difference @ x == x.T @ difference, without materialising x.T.
        gradients_w = np.dot(difference, x) / n
        return gradients_w, gradient_b

    def l1_loss_grad(self, w):
        """Return the element-wise subgradient of ``|w|`` (see ``d_abs``)."""
        return np.array([self.d_abs(each_w) for each_w in w])

    def d_abs(self, x):
        """Return 1.0 where ``x >= 0`` and -1.0 where ``x < 0``."""
        mask = (x >= 0) * 1.0
        mask2 = (x < 0) * -1.0
        return mask + mask2

    def update_model_parameters(self, grad_w, grad_b):
        """Apply one gradient-descent step to the weights and the bias."""
        # Cast back to float32 so the result matches the declared field type.
        self.weights = (self.weights - self.learning_rate * grad_w).astype(np.float32)
        self.bias = self.bias - self.learning_rate * grad_b

    def predict_classes(self, x):
        """Return 1 where the predicted probability exceeds 0.5, else 0."""
        probabilities = self.predict_probs(x)
        return (probabilities > 0.5).astype(np.int64)

    def predict_probs(self, x):
        """Return the predicted probability for every row of ``x``.

        ``x`` must be a 2-D float32 array so that its dtype matches the weights.
        """
        wTx = np.dot(x, self.weights) + self.bias
        return self._sigmoid(wTx)

    def _sigmoid(self, x):
        """Apply the logistic function to every element of the 1-D array ``x``."""
        out = np.empty(x.shape[0], dtype=np.float64)
        for i in range(x.shape[0]):
            out[i] = self._sigmoid_function(x[i])
        return out

    def _sigmoid_function(self, x):
        """Return the logistic function of the scalar ``x``.

        The branch keeps the argument of ``exp`` non-positive, so the
        exponential cannot overflow for large ``|x|``.
        """
        if x >= 0:
            z = np.exp(-x)
            return 1 / (1 + z)
        else:
            z = np.exp(x)
            return z / (1 + z)

    def _transform_x(self, x):
        """Return ``x`` converted with ``np.array``."""
        return np.array(x)

    def _transform_y(self, y):
        """Return ``y`` reshaped into a column of shape ``(len(y), 1)``."""
        return np.array(y.reshape(y.shape[0], 1))
